This is an illustration of the work by Julia Silge and Allison Horst.
Here I use repeated k-fold cross-validation to see whether the model's performance varies across repeated folds.
I also illustrate how to tune the classification threshold of a logistic regression model in order to control the sensitivity and specificity of the fitted model.
Palmer Penguins
In k-fold cross-validation with k = 5, the data is randomly split into 5 folds. We then iteratively fit the model on 4 folds and test it on the remaining fold. We do this 5 times, record the model's performance each time, and take the average. This gives a robust estimate of performance.
Because the initial split is random, a single run may not capture all the variability in the data. To get an even more robust estimate, we repeat the 5-fold cross-validation several times and average over all repetitions.
5-fold Cross Validation
Repeated 5-fold Cross Validation
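As a rough sketch of the idea above, here is repeated 5-fold cross-validation by hand in base R, on toy data with a simple linear model (illustrative only; the rest of this post uses rsample::vfold_cv() instead):

```r
set.seed(123)
# Toy data: y depends linearly on x plus noise
dat <- data.frame(x = rnorm(100))
dat$y <- 2 * dat$x + rnorm(100)

k <- 5; n_repeats <- 3
repeat_means <- sapply(seq_len(n_repeats), function(r) {
  # Randomly assign each row to one of k folds (a fresh split per repeat)
  folds <- sample(rep(seq_len(k), length.out = nrow(dat)))
  fold_rmse <- sapply(seq_len(k), function(i) {
    train <- dat[folds != i, ]
    test  <- dat[folds == i, ]
    fit <- lm(y ~ x, data = train)
    # RMSE on the held-out fold
    sqrt(mean((test$y - predict(fit, newdata = test))^2))
  })
  mean(fold_rmse)   # average over the k folds
})
mean(repeat_means)  # average over all repetitions
```

The inner average summarizes one k-fold run; the outer average smooths out the randomness of the fold assignment itself.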
if (!require('palmerpenguins')) devtools::install_github("allisonhorst/palmerpenguins"); library('palmerpenguins')
require(ggplot2)
require(tidymodels)
penguins
## # A tibble: 344 x 8
## species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
## <fct> <fct> <dbl> <dbl> <int> <int>
## 1 Adelie Torgersen 39.1 18.7 181 3750
## 2 Adelie Torgersen 39.5 17.4 186 3800
## 3 Adelie Torgersen 40.3 18 195 3250
## 4 Adelie Torgersen NA NA NA NA
## 5 Adelie Torgersen 36.7 19.3 193 3450
## 6 Adelie Torgersen 39.3 20.6 190 3650
## 7 Adelie Torgersen 38.9 17.8 181 3625
## 8 Adelie Torgersen 39.2 19.6 195 4675
## 9 Adelie Torgersen 34.1 18.1 193 3475
## 10 Adelie Torgersen 42 20.2 190 4250
## # ... with 334 more rows, and 2 more variables: sex <fct>, year <int>
glimpse(penguins)
## Rows: 344
## Columns: 8
## $ species <fct> Adelie, Adelie, Adelie, Adelie, Adelie, Adelie, Adel~
## $ island <fct> Torgersen, Torgersen, Torgersen, Torgersen, Torgerse~
## $ bill_length_mm <dbl> 39.1, 39.5, 40.3, NA, 36.7, 39.3, 38.9, 39.2, 34.1, ~
## $ bill_depth_mm <dbl> 18.7, 17.4, 18.0, NA, 19.3, 20.6, 17.8, 19.6, 18.1, ~
## $ flipper_length_mm <int> 181, 186, 195, NA, 193, 190, 181, 195, 193, 190, 186~
## $ body_mass_g <int> 3750, 3800, 3250, NA, 3450, 3650, 3625, 4675, 3475, ~
## $ sex <fct> male, female, female, NA, female, male, female, male~
## $ year <int> 2007, 2007, 2007, 2007, 2007, 2007, 2007, 2007, 2007~
Today we will build a logistic regression model to predict the sex of the Palmer penguins, using body part lengths and body mass.
We plot flipper length against bill length, with body mass mapped to point size in a bubble plot, to look for any relationship with the sex of the penguins.
penguins %>%
filter(!is.na(sex)) %>%
ggplot(aes(flipper_length_mm, bill_length_mm, color = sex, size = body_mass_g)) +
geom_point(alpha = 0.5) +
facet_wrap(~species)

penguins %>%
drop_na() %>%
select(species, body_mass_g, ends_with("_mm"), sex) %>%
GGally::ggpairs(aes(color = sex, alpha = 0.5))

Looks like there is a relationship between sex and the body part lengths and body mass of the penguins.
Body Parts of Penguins
We drop year and island from the data, and also drop observations with missing sex.
penguins_df <- penguins %>%
filter(!is.na(sex)) %>%
select(-year, -island)
penguins_df
## # A tibble: 333 x 6
## species bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex
## <fct> <dbl> <dbl> <int> <int> <fct>
## 1 Adelie 39.1 18.7 181 3750 male
## 2 Adelie 39.5 17.4 186 3800 female
## 3 Adelie 40.3 18 195 3250 female
## 4 Adelie 36.7 19.3 193 3450 female
## 5 Adelie 39.3 20.6 190 3650 male
## 6 Adelie 38.9 17.8 181 3625 female
## 7 Adelie 39.2 19.6 195 4675 male
## 8 Adelie 41.1 17.6 182 3200 female
## 9 Adelie 38.6 21.2 191 3800 male
## 10 Adelie 34.6 21.1 198 4400 male
## # ... with 323 more rows
levels(penguins_df$sex)
## [1] "female" "male"
We use the tidymodels framework to model the sex of the Palmer penguins.
set.seed(123)
penguin_split <- initial_split(penguins_df, strata = sex)
penguin_train <- training(penguin_split)
penguin_test <- testing(penguin_split)
penguin_split
## <Analysis/Assess/Total>
## <250/83/333>
penguin_cv <- vfold_cv(data = penguin_train, v = 10, repeats = 10, strata = sex)
penguin_cv
## # 10-fold cross-validation repeated 10 times using stratification
## # A tibble: 100 x 3
## splits id id2
## <list> <chr> <chr>
## 1 <split [224/26]> Repeat01 Fold01
## 2 <split [224/26]> Repeat01 Fold02
## 3 <split [224/26]> Repeat01 Fold03
## 4 <split [224/26]> Repeat01 Fold04
## 5 <split [225/25]> Repeat01 Fold05
## 6 <split [225/25]> Repeat01 Fold06
## 7 <split [226/24]> Repeat01 Fold07
## 8 <split [226/24]> Repeat01 Fold08
## 9 <split [226/24]> Repeat01 Fold09
## 10 <split [226/24]> Repeat01 Fold10
## # ... with 90 more rows
glm_spec <- logistic_reg() %>%
set_engine("glm")

penguin_wf <- workflow() %>%
add_formula(sex ~ .)

fit_resamples() fits the logistic model on each of the 100 analysis sets in penguin_cv and evaluates it on each of the 100 corresponding assessment sets. It also saves the predictions, so we can evaluate the model's performance on each resample.
### Parallel Processing makes things faster
### tidymodels support parallel processing
doParallel::registerDoParallel()
glm_rs <- penguin_wf %>%
add_model(glm_spec) %>%
fit_resamples(
resamples = penguin_cv,
control = control_resamples(save_pred = TRUE, verbose = TRUE)
)
glm_rs
## # Resampling results
## # 10-fold cross-validation repeated 10 times using stratification
## # A tibble: 100 x 6
## splits id id2 .metrics .notes .predictions
## <list> <chr> <chr> <list> <list> <list>
## 1 <split [224/26~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [26 x 6~
## 2 <split [224/26~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [26 x 6~
## 3 <split [224/26~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [26 x 6~
## 4 <split [224/26~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [26 x 6~
## 5 <split [225/25~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [25 x 6~
## 6 <split [225/25~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [25 x 6~
## 7 <split [226/24~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [24 x 6~
## 8 <split [226/24~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [24 x 6~
## 9 <split [226/24~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [24 x 6~
## 10 <split [226/24~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [24 x 6~
## # ... with 90 more rows
The accuracy and AUC below are means over all 100 CV resamples.
collect_metrics(glm_rs)
## # A tibble: 2 x 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.900 100 0.00589 Preprocessor1_Model1
## 2 roc_auc binary 0.965 100 0.00306 Preprocessor1_Model1
glm_rs %>%
unnest(.metrics) %>%
ggplot(aes(id2, .estimate, color = .metric)) +
geom_point() +
labs(title = "Accuracy and AUC over Folds and Repetitions",
x = "Fold",
y = NULL,
color = "Metric") +
facet_wrap(.metric ~ id) +
theme(axis.text.x = element_text(size=6, angle = 90))

We can also look at the average counts in the confusion matrix over the resamples.
glm_rs %>%
conf_mat_resampled()
## # A tibble: 4 x 3
## Prediction Truth Freq
## <fct> <fct> <dbl>
## 1 female female 112.
## 2 female male 12.7
## 3 male female 12.2
## 4 male male 113.
The ROC curve shows similar performance over repeats, although some variation is seen over the folds within repeats.
glm_rs %>%
collect_predictions() %>%
group_by(id, id2) %>%
roc_curve(sex, .pred_female) %>%
ggplot(aes(1 - specificity, sensitivity, color = id2)) +
geom_abline(lty = 2, color = "gray80", size = 1.5) +
geom_path(show.legend = TRUE, alpha = 0.5, size = 0.8) +
coord_equal() +
facet_wrap(~id) +
labs(color='Fold', x = "1 - Specificity", y = "Sensitivity", title = "ROC & AUC by Fold and Repeat") +
theme_minimal()

Finally, last_fit() fits the model on the full training set and evaluates it on the test set.

penguin_final <- penguin_wf %>%
add_model(glm_spec) %>%
last_fit(penguin_split)
penguin_final
## # Resampling results
## # Manual resampling
## # A tibble: 1 x 6
## splits id .metrics .notes .predictions .workflow
## <list> <chr> <list> <list> <list> <list>
## 1 <split [250~ train/test ~ <tibble [2 x~ <tibble [0 ~ <tibble [83 x ~ <workflo~
The metrics on the test data show performance similar to the CV estimates. This indicates an absence of overfitting, and good predictive performance of the logistic model on new data.
collect_metrics(penguin_final)
## # A tibble: 2 x 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.940 Preprocessor1_Model1
## 2 roc_auc binary 0.991 Preprocessor1_Model1
collect_predictions(penguin_final) %>%
conf_mat(sex, .pred_class)
## Truth
## Prediction female male
## female 39 3
## male 2 39
collect_predictions(penguin_final) %>%
sensitivity(sex, .pred_class)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 sens binary 0.951
collect_predictions(penguin_final) %>%
specificity(sex, .pred_class)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 spec binary 0.929
collect_predictions(penguin_final) %>%
precision(sex, .pred_class)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 precision binary 0.929
The roc_curve() function computes sensitivity and specificity at every threshold value and returns them as a tibble.
### First collect predictions on the test data from the model
penguin_final %>%
collect_predictions()
## # A tibble: 83 x 7
## id .pred_female .pred_male .row .pred_class sex .config
## <chr> <dbl> <dbl> <int> <fct> <fct> <chr>
## 1 train/test s~ 0.0117 0.988 9 male male Preprocessor1_~
## 2 train/test s~ 0.499 0.501 12 male fema~ Preprocessor1_~
## 3 train/test s~ 0.0000458 1.00 13 male male Preprocessor1_~
## 4 train/test s~ 0.985 0.0155 24 female fema~ Preprocessor1_~
## 5 train/test s~ 0.992 0.00754 26 female fema~ Preprocessor1_~
## 6 train/test s~ 0.729 0.271 27 female male Preprocessor1_~
## 7 train/test s~ 0.0262 0.974 32 male male Preprocessor1_~
## 8 train/test s~ 0.415 0.585 42 male male Preprocessor1_~
## 9 train/test s~ 0.000236 1.00 44 male male Preprocessor1_~
## 10 train/test s~ 0.835 0.165 45 female fema~ Preprocessor1_~
## # ... with 73 more rows
### Then construct roc curve with these predictions
penguin_final %>%
collect_predictions() %>%
roc_curve(sex, .pred_female)
## # A tibble: 85 x 3
## .threshold specificity sensitivity
## <dbl> <dbl> <dbl>
## 1 -Inf 0 1
## 2 1.24e-7 0 1
## 3 1.24e-5 0.0238 1
## 4 1.78e-5 0.0476 1
## 5 4.58e-5 0.0714 1
## 6 7.60e-5 0.0952 1
## 7 1.12e-4 0.119 1
## 8 1.18e-4 0.143 1
## 9 1.49e-4 0.167 1
## 10 2.36e-4 0.190 1
## # ... with 75 more rows
Using this tibble, we can plot the ROC curve and find an optimal threshold for classifying female penguins.
### Hover your mouse over this plot
r <- penguin_final %>%
collect_predictions() %>%
roc_curve(sex, .pred_female) %>%
ggplot(aes(1 - specificity, sensitivity)) +
geom_point(size = 0.2, aes(color = .threshold)) +
geom_abline(lty = 2,
color = "gray80",
size = 1.5) +
geom_path(show.legend = TRUE,
alpha = 0.3,
size = 0.5) +
geom_text(aes(label = round(.threshold, 2)),
size = 2.5,
vjust = -0.5,
fontface = "bold")
plotly::ggplotly(r)

The ROC curve shows that a threshold of 0.75 gives 100% specificity. By raising the threshold from the default 0.50 to 0.75, we can predict males more accurately, at the cost of a small error in predicting females.
We can use the probably package to change the threshold and make new predictions.
### We need probably package
library(probably)
### set threshold
thresh <- 0.75
### For more information,
### run ?roc_curve
### run ?make_two_class_pred
### Mutate .pred_class with new threshold
new_preds <- penguin_final %>%
collect_predictions() %>%
### mutate .pred_class with new threshold
mutate(.pred_class = make_two_class_pred(.pred_female, ### Predicted Probability
levels(sex),
threshold = thresh), ### Threshold
.pred_class = factor(.pred_class, levels = levels(sex)))
### With New Threshold, Performance on Test Data
### Confusion Matrix
new_preds %>%
conf_mat(sex, .pred_class)
## Truth
## Prediction female male
## female 38 0
## male 3 42
### Sensitivity
new_preds %>%
sensitivity(sex, .pred_class)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 sens binary 0.927
### Specificity
new_preds %>%
specificity(sex, .pred_class)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 spec binary 1
### Precision
new_preds %>%
precision(sex, .pred_class)
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 precision binary 1
We can see that the new threshold performs better in terms of specificity. This could be useful if the modeler (say, a biologist) needs to classify males more accurately than females for scientific purposes. We could instead set threshold = 0.17 for 100% sensitivity, if we wanted to classify females more accurately than males.
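To see where a value like 0.17 comes from, here is a toy sketch in base R (the probabilities below are made up; the real ones come from collect_predictions()): the largest threshold that keeps sensitivity at 100% is the smallest predicted female probability among the true females.

```r
# Hypothetical predicted probabilities of "female" and the true classes
pred_female <- c(0.95, 0.80, 0.17, 0.60, 0.05, 0.40)
truth       <- c("female", "female", "female", "male", "male", "male")

# Any threshold at or below the smallest probability among true females
# still classifies every true female correctly:
max_thresh <- min(pred_female[truth == "female"])
max_thresh

# With this threshold, sensitivity for "female" is 100%
pred_class <- ifelse(pred_female >= max_thresh, "female", "male")
all(pred_class[truth == "female"] == "female")
```

Lowering the threshold this way trades specificity for sensitivity, the mirror image of raising it to 0.75.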
By default tidymodels predicts with threshold = 0.50 for logistic regression.
Bill depth and bill length appear to have the highest importance in predicting the sex of the penguins; these two variables separate the sexes most clearly.
penguin_final$.workflow[[1]] %>%
tidy(exponentiate = TRUE)
## # A tibble: 7 x 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 (Intercept) 3.12e-35 13.5 -5.90 0.00000000369
## 2 speciesChinstrap 1.34e- 3 1.70 -3.89 0.000101
## 3 speciesGentoo 1.08e- 4 2.89 -3.16 0.00159
## 4 bill_length_mm 1.78e+ 0 0.137 4.20 0.0000268
## 5 bill_depth_mm 3.89e+ 0 0.373 3.64 0.000273
## 6 flipper_length_mm 1.07e+ 0 0.0538 1.31 0.189
## 7 body_mass_g 1.01e+ 0 0.00108 4.70 0.00000260
penguin_final$.workflow[[1]] %>%
tidy(exponentiate = TRUE) %>%
select(term, estimate) %>%
mutate(term = as.factor(term)) %>%
ggplot(aes(reorder(term, estimate), estimate,
fill = term)) +
geom_bar(stat = "identity", show.legend = FALSE, width = 0.7) +
labs(title = "Increase in odds of a penguin being male per one-unit increase in each variable",
x = "Variable", y = "Odds ratio") +
geom_text(aes(label = round(estimate, 3)),
nudge_y = 0.15 ,
size = 4.5,
colour = 'black',
fontface = 'bold') +
coord_flip() +
theme_light()

A 1 mm increase in bill depth multiplies the odds of a penguin being male by almost 4 (glm models the odds of the second factor level, here "male"). The species terms adjust the baseline odds: at the same measurements, Chinstrap and Gentoo penguins have much lower odds of being male than Adelie penguins.
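As a quick arithmetic check of that interpretation: the tidied estimates above are exponentiated coefficients, i.e. odds ratios, so an odds ratio of about 3.89 for bill_depth_mm corresponds to a raw log-odds coefficient of log(3.89):

```r
# Odds ratio reported by tidy(exponentiate = TRUE) for bill_depth_mm
or_bill_depth <- 3.89
# The raw (log-odds) coefficient is its natural log, roughly 1.36:
beta_bill_depth <- log(or_bill_depth)
round(beta_bill_depth, 2)
```

Each extra millimeter of bill depth therefore adds about 1.36 to the log-odds, multiplying the odds of the modeled class by about 3.89.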
penguins %>%
filter(!is.na(sex)) %>%
ggplot(aes(bill_depth_mm, bill_length_mm, color = sex, size = body_mass_g)) +
geom_point(alpha = 0.5) +
facet_wrap(~species)

Bill Length and Bill Depth